{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**SageMaker Repo:  Workflow/Airflow**:  \n",
    "https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/workflow/airflow.py\n",
    "\n",
    "**Airflow Repo:  SageMaker Operators**: \n",
    "https://github.com/apache/airflow/tree/master/airflow/providers/amazon/aws\n",
    "\n",
    "**Airflow Workshop**:\n",
    "https://www.sagemakerworkshop.com/airflow\n",
    "\n",
    "**Blog Post for Airflow Workshop**:\n",
    "https://aws.amazon.com/blogs/machine-learning/build-end-to-end-machine-learning-workflows-with-amazon-sagemaker-and-apache-airflow/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define the Directed Acyclic Graph (DAG) of Execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datetime import datetime\n",
    "\n",
    "import airflow\n",
    "from airflow import DAG\n",
    "\n",
    "# Arguments inherited by every task attached to the DAG.\n",
    "default_args = {\n",
    "    'owner': 'airflow',\n",
    "    # Airflow refuses to register a task that has no start_date, so one must be\n",
    "    # supplied here. A fixed date in the past lets the single '@once' run be\n",
    "    # scheduled as soon as the DAG is enabled.\n",
    "    'start_date': datetime(2020, 1, 1),\n",
    "    'provide_context': True\n",
    "}\n",
    "\n",
    "# '@once' schedules exactly one run of the pipeline.\n",
    "dag = DAG('bert_reviews',\n",
    "          default_args=default_args,\n",
    "          schedule_interval='@once')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from airflow.operators.dummy_operator import DummyOperator\n",
    "\n",
    "# No-op task that serves as the single entry point of the DAG.\n",
    "init = DummyOperator(\n",
    "    task_id='start',\n",
    "    dag=dag\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMaker Processing Job Operator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Builds the processing-job configuration and wraps it in an Airflow task\n",
    "# ('process') attached to `dag`.\n",
    "#\n",
    "# NOTE(review): `estimator` is first assigned in the 'SageMaker Training Job\n",
    "# Operator' section further down, so on a fresh kernel (Restart & Run All) this\n",
    "# cell raises NameError.\n",
    "# NOTE(review): the SageMaker SDK's processing_config appears to take a\n",
    "# `processor` (a Processor object) rather than an `estimator` keyword, and\n",
    "# lists of processing inputs/outputs rather than bare S3 URIs - confirm against\n",
    "# the SDK documentation and define a processor before this cell.\n",
    "# NOTE(review): `input_data_s3_uri` and `output_data_s3_uri` are not defined\n",
    "# anywhere in this notebook; presumably they come from an earlier session.\n",
    "from airflow.contrib.operators.sagemaker_processing_operator import SageMakerProcessingOperator\n",
    "from sagemaker.workflow.airflow import processing_config\n",
    "\n",
    "process_config = processing_config(estimator=estimator,\n",
    "                                   inputs=input_data_s3_uri,\n",
    "                                   outputs=output_data_s3_uri)\n",
    "\n",
    "process_op = SageMakerProcessingOperator(\n",
    "    task_id='process',\n",
    "    config=process_config,\n",
    "    wait_for_completion=True,\n",
    "    dag=dag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge start -> process, declared from the upstream task's side.\n",
    "init.set_downstream(process_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMaker Training Job Operator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.tensorflow import TensorFlow\n",
    "\n",
    "# TensorFlow estimator for the BERT reviews training script in src/.\n",
    "# Every non-literal value below is expected to already exist in the notebook\n",
    "# namespace.\n",
    "estimator = TensorFlow(\n",
    "    # Training code\n",
    "    entry_point='tf_bert_reviews.py',\n",
    "    source_dir='src',\n",
    "    # Execution role and cluster shape\n",
    "    role=role,\n",
    "    instance_count=train_instance_count,\n",
    "    instance_type=train_instance_type,\n",
    "    volume_size=train_volume_size,\n",
    "    # Spot training with checkpoints\n",
    "    use_spot_instances=True,\n",
    "    max_wait=7200,  # seconds to wait for spot instances to become available\n",
    "    checkpoint_s3_uri=checkpoint_s3_uri,\n",
    "    # Framework container\n",
    "    py_version='py3',\n",
    "    framework_version='2.1.0',\n",
    "    # Values forwarded to the training script\n",
    "    hyperparameters={\n",
    "        'epochs': epochs,\n",
    "        'learning_rate': learning_rate,\n",
    "        'epsilon': epsilon,\n",
    "        'train_batch_size': train_batch_size,\n",
    "        'validation_batch_size': validation_batch_size,\n",
    "        'test_batch_size': test_batch_size,\n",
    "        'train_steps_per_epoch': train_steps_per_epoch,\n",
    "        'validation_steps': validation_steps,\n",
    "        'test_steps': test_steps,\n",
    "        'use_xla': use_xla,\n",
    "        'use_amp': use_amp,\n",
    "        'max_seq_length': max_seq_length,\n",
    "        'freeze_bert_layer': freeze_bert_layer,\n",
    "        'enable_sagemaker_debugger': enable_sagemaker_debugger,\n",
    "        'enable_checkpointing': enable_checkpointing,\n",
    "        'enable_tensorboard': enable_tensorboard,\n",
    "        'run_validation': run_validation,\n",
    "        'run_test': run_test,\n",
    "        'run_sample_predictions': run_sample_predictions,\n",
    "    },\n",
    "    input_mode=input_mode,\n",
    "    # Metrics and debugger settings\n",
    "    metric_definitions=metrics_definitions,\n",
    "    rules=rules,\n",
    "    debugger_hook_config=hook_config,\n",
    "    max_run=7200,  # number of seconds\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from airflow.contrib.operators.sagemaker_training_operator import SageMakerTrainingOperator\n",
    "from sagemaker.workflow.airflow import training_config\n",
    "\n",
    "# Training-job configuration derived from the estimator and the input location.\n",
    "train_config = training_config(\n",
    "    estimator=estimator,\n",
    "    inputs=input_data_s3_uri,\n",
    ")\n",
    "\n",
    "# Airflow task 'train', attached to `dag`, that runs the configuration above.\n",
    "train_op = SageMakerTrainingOperator(\n",
    "    task_id='train',\n",
    "    config=train_config,\n",
    "    wait_for_completion=True,\n",
    "    dag=dag,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge process -> train, declared from the upstream task's side.\n",
    "process_op.set_downstream(train_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMaker Model Operator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from airflow.contrib.operators.sagemaker_model_operator import SageMakerModelOperator\n",
    "from sagemaker.workflow.airflow import model_config_from_estimator\n",
    "\n",
    "# Instance type used to pick the serving container image; adjust it to the\n",
    "# hardware the model will actually be hosted on.\n",
    "deploy_instance_type = 'ml.m5.large'\n",
    "\n",
    "# The operator needs a configuration built by the SDK, not the config-builder\n",
    "# function itself. task_id/task_type point at the 'train' task so the model is\n",
    "# created from the artifacts of that task's training job.\n",
    "model_op_config = model_config_from_estimator(estimator=estimator,\n",
    "                                              task_id='train',\n",
    "                                              task_type='training',\n",
    "                                              instance_type=deploy_instance_type)\n",
    "\n",
    "# SageMakerModelOperator has no wait_for_completion argument, so none is passed.\n",
    "model_op = SageMakerModelOperator(\n",
    "    task_id='model',\n",
    "    config=model_op_config,\n",
    "    dag=dag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge train -> model, declared from the upstream task's side.\n",
    "train_op.set_downstream(model_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMaker Endpoint Operator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from airflow.contrib.operators.sagemaker_endpoint_operator import SageMakerEndpointOperator\n",
    "from sagemaker.workflow.airflow import deploy_config_from_estimator\n",
    "\n",
    "# Endpoint sizing; adjust to the expected inference load.\n",
    "deploy_instance_count = 1\n",
    "deploy_instance_type = 'ml.m5.large'\n",
    "\n",
    "# The operator needs a configuration built by the SDK, not the config-builder\n",
    "# function itself. task_id/task_type point at the 'train' task so the endpoint\n",
    "# serves the model produced by that task's training job.\n",
    "deploy_config = deploy_config_from_estimator(estimator=estimator,\n",
    "                                             task_id='train',\n",
    "                                             task_type='training',\n",
    "                                             initial_instance_count=deploy_instance_count,\n",
    "                                             instance_type=deploy_instance_type)\n",
    "\n",
    "deploy_op = SageMakerEndpointOperator(\n",
    "    task_id='deploy',\n",
    "    config=deploy_config,\n",
    "    wait_for_completion=True,\n",
    "    dag=dag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Edge model -> deploy, declared from the upstream task's side.\n",
    "model_op.set_downstream(deploy_op)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup DAG Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Linear pipeline: start -> process -> train -> model -> deploy.\n",
    "init.set_downstream(process_op)\n",
    "process_op.set_downstream(train_op)\n",
    "train_op.set_downstream(model_op)\n",
    "model_op.set_downstream(deploy_op)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
